"""Training loop with curriculum learning for subgraph isomorphism RL."""

import logging
import os
from collections import deque
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import torch
from tqdm import tqdm

from .agent import PPOAgent, PPOConfig
from .environment import SubgraphIsomorphismEnv, RewardParams
from .data import generate_random_graph, create_pattern_target_pair
from .evaluation import evaluate_agent


@dataclass
class CurriculumConfig:
    """Bounds and schedule for the training curriculum's graph difficulty.

    The trainer starts at the ``start_*`` sizes and, whenever the recent
    success rate is high enough, grows pattern and target graphs by one node
    each until the respective ``max_*`` bound is reached.
    """

    # Number of pattern-graph nodes at the first curriculum level.
    start_pattern_size: int = 3
    # Pattern size is never increased beyond this value.
    max_pattern_size: int = 10
    # Number of target-graph nodes at the first curriculum level.
    start_target_size: int = 5
    # Target size is never increased beyond this value.
    max_target_size: int = 20
    # A difficulty increase is considered once every this many episodes.
    size_increment_episodes: int = 1000
    # Edge probability used for target-graph generation at the first level.
    edge_prob_start: float = 0.3
    # Edge probability reached once the pattern size hits max_pattern_size.
    edge_prob_end: float = 0.7
    # The mean recent success rate must strictly exceed this to level up.
    success_threshold: float = 0.8
    # NOTE(review): not read by any code visible in this file; presumably a
    # minimum number of episodes to spend at each level -- confirm or remove.
    min_episodes_per_level: int = 500


class Trainer:
    """Trainer for subgraph isomorphism RL agent.

    Every training episode samples a fresh (pattern, target) graph pair at
    the current curriculum level, rolls the agent out in a new
    ``SubgraphIsomorphismEnv`` and stores each transition in the agent.
    Difficulty (graph sizes and edge probability) is raised periodically when
    the recent success rate exceeds the configured threshold.

    The metric deques (``episode_rewards``, ``episode_lengths``,
    ``success_rates``) only retain the 100 most recent episodes.
    """

    def __init__(
        self,
        agent: PPOAgent,
        curriculum_config: CurriculumConfig,
        device: str = 'cpu',
        log_interval: int = 100,
        save_interval: int = 1000
    ):
        """Store collaborators and reset curriculum and metric state.

        Args:
            agent: PPO agent used for action selection, experience storage,
                policy updates and checkpointing.
            curriculum_config: Bounds and schedule for graph difficulty.
            device: Stored on the instance; not otherwise used by this class.
            log_interval: Log progress every this many episodes; a
                non-positive value disables periodic logging.
            save_interval: Save a checkpoint every this many episodes; a
                non-positive value disables periodic checkpoints.
        """
        self.agent = agent
        self.curriculum_config = curriculum_config
        self.device = device
        self.log_interval = log_interval
        self.save_interval = save_interval

        # Curriculum state
        self.current_pattern_size = curriculum_config.start_pattern_size
        self.current_target_size = curriculum_config.start_target_size
        self.current_edge_prob = curriculum_config.edge_prob_start

        # Training metrics (sliding window over the last 100 episodes)
        self.episode_rewards = deque(maxlen=100)
        self.episode_lengths = deque(maxlen=100)
        self.success_rates = deque(maxlen=100)

        # Logging
        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)

    @staticmethod
    def _is_due(episode: int, interval: int) -> bool:
        """Return True when a periodic action with ``interval`` should fire.

        Non-positive intervals disable the action instead of raising
        ``ZeroDivisionError`` from the modulo.
        """
        return interval > 0 and episode % interval == 0

    def _save_agent(self, save_path: str, filename: str) -> None:
        """Save the agent to ``{save_path}/{filename}``, creating the directory."""
        if save_path:
            # The checkpoint directory usually does not exist on a fresh run.
            os.makedirs(save_path, exist_ok=True)
        self.agent.save(f"{save_path}/{filename}")

    def train(
        self,
        total_episodes: int,
        save_path: str = 'checkpoints',
        eval_interval: int = 500,
        eval_episodes: int = 10
    ):
        """Main training loop.

        Args:
            total_episodes: Number of training episodes to run.
            save_path: Directory for ``best_model.pt`` and periodic
                ``checkpoint_{episode}.pt`` files; created on first save.
            eval_interval: Evaluate every this many episodes; a non-positive
                value disables evaluation.
            eval_episodes: Number of evaluation episodes per evaluation.

        Returns:
            Tuple ``(episode_rewards, success_rates)`` -- the deques holding
            the most recent (at most 100) episode rewards and success flags.
        """
        self.logger.info(f"Starting training for {total_episodes} episodes")

        episode = 0
        best_success_rate = 0.0

        with tqdm(total=total_episodes, desc="Training") as pbar:
            while episode < total_episodes:
                # Generate curriculum graphs
                pattern, target = self._generate_curriculum_graphs()

                # Create environment
                env = SubgraphIsomorphismEnv(pattern, target)

                # Run episode
                episode_reward, episode_length, success = self._run_episode(env)

                # Update metrics
                self.episode_rewards.append(episode_reward)
                self.episode_lengths.append(episode_length)
                self.success_rates.append(float(success))

                episode += 1

                # Update curriculum
                if self._is_due(episode, self.curriculum_config.size_increment_episodes):
                    self._update_curriculum()

                # Logging
                if self._is_due(episode, self.log_interval):
                    self._log_progress(episode)

                # Evaluation
                if self._is_due(episode, eval_interval):
                    eval_success_rate = self._evaluate(eval_episodes)

                    # Only a strict improvement replaces the best model.
                    if eval_success_rate > best_success_rate:
                        best_success_rate = eval_success_rate
                        self._save_agent(save_path, "best_model.pt")

                    self.logger.info(f"Evaluation at episode {episode}: {eval_success_rate:.3f}")

                # Save checkpoint
                if self._is_due(episode, self.save_interval):
                    self._save_agent(save_path, f"checkpoint_{episode}.pt")

                # Update progress bar
                pbar.set_postfix({
                    'reward': f"{np.mean(self.episode_rewards):.2f}",
                    'success': f"{np.mean(self.success_rates):.2f}",
                    'pattern_size': self.current_pattern_size
                })
                pbar.update(1)

                # PPO update once enough experience has been collected
                if self.agent.buffer.size >= self.agent.config.batch_size:
                    update_stats = self.agent.update()

                    if self._is_due(episode, self.log_interval):
                        self.logger.info(f"Update stats: {update_stats}")

        self.logger.info("Training completed")
        return self.episode_rewards, self.success_rates

    def _run_episode(self, env: SubgraphIsomorphismEnv) -> Tuple[float, int, bool]:
        """Run single training episode and store its transitions in the agent.

        Args:
            env: Freshly constructed environment for this episode.

        Returns:
            Tuple ``(episode_reward, episode_length, success)`` where
            ``success`` is the result of ``env._is_complete()`` after the
            episode ends.
        """
        observation, _ = env.reset()
        episode_reward = 0.0
        episode_length = 0
        done = False

        while not done:
            # Select action
            action, log_prob, value = self.agent.select_action(observation)

            # Take step. The episode ends on termination OR truncation;
            # ignoring the truncation flag would loop forever on a time limit.
            next_observation, reward, terminated, truncated, _ = env.step(action)
            done = bool(terminated or truncated)

            # Store experience
            self.agent.store_experience(
                observation, action, reward, value, log_prob, done
            )

            # Update
            observation = next_observation
            episode_reward += reward
            episode_length += 1

        # Check if successful
        success = env._is_complete()

        return episode_reward, episode_length, success

    def _generate_curriculum_graphs(self) -> Tuple[nx.Graph, nx.Graph]:
        """Generate graphs according to current curriculum level.

        Returns:
            Tuple ``(pattern, target)`` as returned by
            ``create_pattern_target_pair`` for a random target graph of the
            current target size and edge probability.
        """
        # Generate target graph
        target = generate_random_graph(
            n_nodes=self.current_target_size,
            edge_prob=self.current_edge_prob,
            seed=np.random.randint(0, 10000)
        )

        # Create pattern from target
        pattern, target = create_pattern_target_pair(
            target,
            pattern_size=self.current_pattern_size,
            seed=np.random.randint(0, 10000)
        )

        return pattern, target

    def _update_curriculum(self):
        """Update curriculum difficulty.

        When the mean of the recent success flags strictly exceeds
        ``success_threshold``, pattern and target sizes each grow by one node
        (up to their maxima) and the edge probability is re-interpolated
        between ``edge_prob_start`` and ``edge_prob_end``.
        """
        cfg = self.curriculum_config
        success_rate = np.mean(self.success_rates) if self.success_rates else 0.0

        # Increase difficulty only if success rate is high
        if success_rate <= cfg.success_threshold:
            return

        if self.current_pattern_size < cfg.max_pattern_size:
            self.current_pattern_size += 1
            self.logger.info(f"Increased pattern size to {self.current_pattern_size}")

        if self.current_target_size < cfg.max_target_size:
            self.current_target_size += 1
            self.logger.info(f"Increased target size to {self.current_target_size}")

        # Gradually increase edge probability. Progress is measured relative
        # to the starting pattern size so the schedule begins at
        # edge_prob_start instead of jumping on the first level-up.
        size_span = cfg.max_pattern_size - cfg.start_pattern_size
        if size_span > 0:
            progress = (self.current_pattern_size - cfg.start_pattern_size) / size_span
            progress = min(1.0, max(0.0, progress))
        else:
            # No room to grow: the curriculum is already at its final level.
            progress = 1.0

        edge_prob_range = cfg.edge_prob_end - cfg.edge_prob_start
        self.current_edge_prob = cfg.edge_prob_start + progress * edge_prob_range

    def _evaluate(self, eval_episodes: int) -> float:
        """Evaluate agent performance at the current curriculum level.

        Runs ``eval_episodes`` episodes with deterministic action selection;
        no experience is stored.

        Args:
            eval_episodes: Number of evaluation episodes.

        Returns:
            Fraction of episodes for which ``env._is_complete()`` is true, or
            ``0.0`` when ``eval_episodes`` is not positive.
        """
        if eval_episodes <= 0:
            return 0.0

        success_count = 0

        for _ in range(eval_episodes):
            pattern, target = self._generate_curriculum_graphs()
            env = SubgraphIsomorphismEnv(pattern, target)

            observation, _ = env.reset()
            done = False

            while not done:
                action, _, _ = self.agent.select_action(observation, deterministic=True)
                # Stop on termination or truncation (see _run_episode).
                observation, _, terminated, truncated, _ = env.step(action)
                done = bool(terminated or truncated)

            if env._is_complete():
                success_count += 1

        return success_count / eval_episodes

    def _log_progress(self, episode: int):
        """Log training progress averaged over the recent-episode window.

        Args:
            episode: Number of episodes completed so far.
        """
        avg_reward = np.mean(self.episode_rewards) if self.episode_rewards else 0.0
        avg_length = np.mean(self.episode_lengths) if self.episode_lengths else 0.0
        success_rate = np.mean(self.success_rates) if self.success_rates else 0.0

        self.logger.info(
            f"Episode {episode}: "
            f"Avg Reward: {avg_reward:.2f}, "
            f"Avg Length: {avg_length:.1f}, "
            f"Success Rate: {success_rate:.2f}, "
            f"Pattern Size: {self.current_pattern_size}, "
            f"Target Size: {self.current_target_size}"
        )

    def plot_training_curves(self, save_path: Optional[str] = None):
        """Plot training curves.

        Draws a 2x2 grid: recent episode rewards, lengths and success flags
        (at most the last 100 episodes each) plus the agent's policy/value
        losses when ``agent.training_stats['policy_loss']`` is non-empty.

        Args:
            save_path: File to write the figure to. When omitted or empty the
                figure is shown with ``plt.show()`` instead.
        """
        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # Episode rewards
        axes[0, 0].plot(self.episode_rewards)
        axes[0, 0].set_title('Episode Rewards')
        axes[0, 0].set_xlabel('Episode')
        axes[0, 0].set_ylabel('Reward')

        # Episode lengths
        axes[0, 1].plot(self.episode_lengths)
        axes[0, 1].set_title('Episode Lengths')
        axes[0, 1].set_xlabel('Episode')
        axes[0, 1].set_ylabel('Steps')

        # Success rates
        axes[1, 0].plot(self.success_rates)
        axes[1, 0].set_title('Success Rates')
        axes[1, 0].set_xlabel('Episode')
        axes[1, 0].set_ylabel('Success Rate')

        # Training losses
        if self.agent.training_stats['policy_loss']:
            axes[1, 1].plot(self.agent.training_stats['policy_loss'], label='Policy Loss')
            axes[1, 1].plot(self.agent.training_stats['value_loss'], label='Value Loss')
            axes[1, 1].set_title('Training Losses')
            axes[1, 1].set_xlabel('Update')
            axes[1, 1].set_ylabel('Loss')
            axes[1, 1].legend()

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path)
            # Release the figure; pyplot otherwise keeps it alive across calls.
            plt.close(fig)
        else:
            plt.show()


def create_trainer(
    node_features: int = 16,
    hidden_dim: int = 128,
    device: str = 'cpu'
) -> Trainer:
    """Build a ``Trainer`` around a freshly constructed PPO agent.

    The PPO hyperparameters are fixed here and the curriculum uses the
    ``CurriculumConfig`` defaults.

    Args:
        node_features: Forwarded to ``PPOAgent`` as ``node_features``.
        hidden_dim: Forwarded to ``PPOAgent`` as ``hidden_dim``.
        device: Forwarded to both the agent and the trainer.

    Returns:
        The configured ``Trainer`` instance.
    """
    # Fixed PPO hyperparameters for the default setup.
    hyperparams = {
        'learning_rate': 3e-4,
        'clip_ratio': 0.2,
        'value_coef': 0.5,
        'entropy_coef': 0.01,
        'ppo_epochs': 4,
        'batch_size': 64,
        'buffer_size': 2048,
    }

    policy = PPOAgent(
        config=PPOConfig(**hyperparams),
        node_features=node_features,
        hidden_dim=hidden_dim,
        device=device,
    )

    # Default curriculum; the trainer shares the agent's device setting.
    return Trainer(
        agent=policy,
        curriculum_config=CurriculumConfig(),
        device=device,
    )